import sys
import os
import pysam
import functools

import pybedtools
from Bio.Seq import reverse_complement
from Bio import SeqIO
from numpy import *

# Map the UCSC assembly name given as the first command-line argument to the
# organism prefix used to select FANTOM5 sRNA library directories.
organisms = {
    'hg19': 'human',
    'hg38': 'human',
    'mm9': 'mouse',
    'rn4': 'rat',
    'canFam2': 'dog',
    'galGal4': 'chicken',
}

if len(sys.argv) < 2:
    sys.exit("Usage: %s ASSEMBLY (one of %s)"
             % (sys.argv[0], ", ".join(sorted(organisms))))

assembly = sys.argv[1]

# Fail immediately on an unsupported assembly instead of hitting a NameError
# on `organism` only after all annotation files have been loaded.
if assembly not in organisms:
    sys.exit("Unknown assembly %r; expected one of %s"
             % (assembly, ", ".join(sorted(organisms))))
organism = organisms[assembly]

# tRNA annotations for this assembly from three sources: NCBI (BED) and the
# UCSC tRNA track and RepeatMasker (both GFF).
filenames = [
    "/osc-fs_home/mdehoon/Data/NCBI/%s/trnas.bed" % assembly,
    "/osc-fs_home/mdehoon/Data/UCSC/tRNA/%s/tRNA.gff" % assembly,
    "/osc-fs_home/mdehoon/Data/RepeatMasker/%s/tRNA.gff" % assembly,
]

# Pool every interval from every source into one flat list; it is sorted and
# grouped further down.
rows = []
for filename in filenames:
    intervals = pybedtools.BedTool(filename)
    for row in intervals:
        rows.append(row)

def _chromosome_number(chromosome):
    """Return the integer after the 3-letter "chr" prefix, or None if not numeric."""
    try:
        return int(chromosome[3:])
    except ValueError:
        return None


def compare_chromosomes(chromosome1, chromosome2):
    """cmp-style comparison of two chromosome names; returns -1, 0 or +1.

    Resulting order (lowest first):
      numbered chromosomes (numerically), other plain names (lexicographically),
      chrUn* (lexicographically), *_random (by their base name),
      chrW, chrZ, chrX, chrY, chrM, and finally *_hapN haplotypes
      (by base chromosome, then by N).

    Raises Exception("Should not get here") when two distinct names compare
    equal under these rules (e.g. "chr1" vs "chr01").
    """
    if chromosome1==chromosome2:
        return 0
    if "hap" in chromosome1 and "hap" in chromosome2:
        # Names like chr6_apd_hap1: order by base chromosome, then hap number.
        chromosome1, name1, hap1 = chromosome1.split("_")
        chromosome2, name2, hap2 = chromosome2.split("_")
        if chromosome1!=chromosome2:
            return compare_chromosomes(chromosome1, chromosome2)
        assert hap1.startswith("hap")
        assert hap2.startswith("hap")
        hap1 = int(hap1[3:])
        hap2 = int(hap2[3:])
        if hap1 < hap2:
            return -1
        if hap1 > hap2:
            return +1
        raise Exception("Should not get here")
    if "hap" in chromosome1:
        return +1
    if "hap" in chromosome2:
        return -1
    # Earlier entries in this tuple sort later: ... < chrW < chrZ < chrX < chrY < chrM.
    for chromosome in ("chrM", "chrY", "chrX", "chrZ", "chrW"):
        if chromosome1==chromosome:
            return +1
        if chromosome2==chromosome:
            return -1
    if "random" in chromosome1 and "random" in chromosome2:
        assert chromosome1.endswith("_random")
        assert chromosome2.endswith("_random")
        chromosome1 = chromosome1[:-7]
        chromosome2 = chromosome2[:-7]
        return compare_chromosomes(chromosome1, chromosome2)
    if "random" in chromosome1:
        return +1
    if "random" in chromosome2:
        return -1
    if chromosome1.startswith("chrUn") and chromosome2.startswith("chrUn"):
        if chromosome1 < chromosome2:
            return -1
        if chromosome1 > chromosome2:
            return +1
        raise Exception("Should not get here")
    if chromosome1.startswith("chrUn"):
        return +1
    if chromosome2.startswith("chrUn"):
        return -1
    number1 = _chromosome_number(chromosome1)
    number2 = _chromosome_number(chromosome2)
    if number1 is not None and number2 is not None:
        key1, key2 = number1, number2
    elif number1 is not None:
        # Numbered chromosomes always precede non-numeric names (e.g. "chrEBV").
        # Comparing the int with the other name's str would raise TypeError,
        # and comparing raw strings instead would make the order intransitive.
        return -1
    elif number2 is not None:
        return +1
    else:
        key1, key2 = chromosome1, chromosome2
    if key1 < key2:
        return -1
    if key1 > key2:
        return +1
    raise Exception("Should not get here")

def compare(row1, row2):
    """cmp-style ordering of two annotation rows; returns -1, 0 or +1.

    Rows are ordered by chromosome (via compare_chromosomes), then by start,
    then by end, and finally by annotation source.  BED rows count as source
    'NCBI'; any other row uses its second column (fields[1]) as the source.
    """
    by_chromosome = compare_chromosomes(row1.chrom, row2.chrom)
    if by_chromosome != 0:
        return by_chromosome

    # Tuple comparison checks start first and falls back to end on a tie.
    span1 = (row1.start, row1.end)
    span2 = (row2.start, row2.end)
    if span1 != span2:
        return -1 if span1 < span2 else +1

    label1 = 'NCBI' if row1.file_type == 'bed' else row1.fields[1]
    label2 = 'NCBI' if row2.file_type == 'bed' else row2.fields[1]
    if label1 == label2:
        return 0
    return -1 if label1 < label2 else +1

# compare() is an old-style cmp function (-1/0/+1), so wrap it as a sort key.
# Resulting order: chromosome, start, end, then annotation source.
sort_key = functools.cmp_to_key(compare)
rows.sort(key=sort_key)

def overlap(row1, row2, threshold=30):
    """Return True if row1 and row2 share more than `threshold` nucleotides.

    Both rows must be on the same chromosome and strand.  The overlap must be
    strictly greater than `threshold`; an overlap of exactly `threshold`
    nucleotides does not count.  A missing row2 (None) never overlaps.
    """
    if row2 is None:
        return False
    if (row1.chrom, row1.strand) != (row2.chrom, row2.strand):
        return False
    shared = min(row1.end, row2.end) - max(row1.start, row2.start)
    return shared > threshold

# Walk the sorted rows and merge neighbours into groups.  A row joins the open
# group when overlap() accepts it against the group's FIRST row (same
# chromosome and strand, more than 30 nt shared); otherwise it opens a new group.
# NOTE(review): since only the first row is compared, an intervening row on the
# opposite strand closes the group even if later rows overlap the first one.
current = None
groups = []
for row in rows:
    if not overlap(row, current):
        current = row
        group = []
        groups.append(group)
    group.append(row)

def read_genome(chromosome):
    """Return the upper-cased genomic sequence of `chromosome` as a str.

    The sequence is read from <Genomes>/<assembly>/<chromosome>.fa using the
    module-level `assembly`.  SeqIO.read requires the file to contain exactly
    one FASTA record and raises otherwise.
    """
    fasta_path = "/osc-fs_home/scratch/mdehoon/Data/Genomes/%s/%s.fa" % (assembly, chromosome)
    record = SeqIO.read(fasta_path, 'fasta')
    return str(record.seq).upper()

# Count sRNA reads by their last 13 nucleotides.  Reverse-strand alignments are
# reverse-complemented first so the tag is taken from the read's own 3' end.
# The 13-nt tag matches the "10 genomic nt + 'CCA'" keys looked up when the
# 3' ends are chosen further down.
counts = {}
directory = "/analysisdata/mirrors/F5_file_current/UPDATE_029/f5pipeline/"
if assembly=='galGal4':
    directory = "/osc-fs_home/mdehoon/Data/Fantom5/sRNA/UPDATE_023/FixedBamFiles/"

subdirectories = os.listdir(directory)
for subdirectory in subdirectories:
    # Only directories named "<organism>.<something>.sRNA" are used.
    terms = subdirectory.split(".")
    if len(terms)!=3:
        continue
    if terms[0]!=organism:
        continue
    if terms[2]!='sRNA':
        continue
    subdirectory = os.path.join(directory, subdirectory)
    filenames = os.listdir(subdirectory)
    for filename in filenames:
        _, extension = os.path.splitext(filename)
        if extension!='.bam':
            continue
        path = os.path.join(subdirectory, filename)
        print("Reading", path)
        # Close each BAM file when done; previously every handle stayed open.
        with pysam.AlignmentFile(path, "rb") as stream:
            for read in stream:
                sequence = read.query_sequence
                if sequence is None:
                    # Record stored without a sequence: there is no tag to count.
                    continue
                if read.is_reverse:
                    sequence = reverse_complement(sequence)
                tag = sequence[-13:]
                counts[tag] = counts.get(tag, 0) + 1
 
def parse_repeatmasker_name(name):
    """Convert a RepeatMasker tRNA label into an amino-acid + codon name.

    The label must start with "tRNA|tRNA-" (asserted).  The trailing markers
    "_", "-i" and "(m)" are stripped, each at most once and in that order.
    The remainder is "<aa>-<codon>" or just "<aa>"; a missing codon becomes
    "???" and the amino acid "SeC(e)" is renamed "Sec".

    Example: "tRNA|tRNA-Ala-GCA" -> "AlaGCA".
    """
    for prefix in ("tRNA|", "tRNA-"):
        assert name.startswith(prefix)
        name = name[len(prefix):]
    for suffix in ("_", "-i", "(m)"):
        if name.endswith(suffix):
            name = name[:-len(suffix)]
    if '-' not in name:
        aminoacid, codon = name, '???'
    else:
        aminoacid, codon = name.split("-")
    if aminoacid == "SeC(e)":
        aminoacid = "Sec"
    return aminoacid + codon
 
def parse_ucsc_name(name):
    """Convert a UCSC tRNA label "tRNA|<aa><anticodon>" into amino acid + codon.

    "Undet???" becomes "Unk???", and the prefixes "Unknown"/"Pseudo" are
    shortened to "Unk"/"Pse".  After that the label must be exactly 6
    characters: a 3-letter amino acid followed by the anticodon.  The anticodon
    is reverse-complemented to give the codon.  Suppressor tRNAs ("Sup") are
    renamed after their stop codon (TAA -> Och, TAG -> Amb, TGA -> Opl), and
    "SeC" is renamed "Sec".

    Raises Exception for a suppressor whose codon is not one of those three.
    """
    assert name.startswith("tRNA|")
    label = name[5:]
    if label == "Undet???":
        label = "Unk???"
    elif label.startswith("Unknown"):
        label = "Unk" + label[len("Unknown"):]
    elif label.startswith("Pseudo"):
        label = "Pse" + label[len("Pseudo"):]
    assert len(label) == 6

    aminoacid = label[:3]
    codon = reverse_complement(label[3:])
    if aminoacid == 'Sup':
        suppressor_names = {'TAA': 'Och', 'TAG': 'Amb', 'TGA': 'Opl'}
        if codon not in suppressor_names:
            raise Exception("Suppressor with inconsistent anticodon")
        aminoacid = suppressor_names[codon]
    elif aminoacid == "SeC":
        aminoacid = "Sec"
    return aminoacid + codon

def parse_ncbi_name(name):
    """Convert an NCBI label "<geneid>:<genename>:<aa><codon>" to amino acid + codon.

    The last field is either 6 characters (3-letter amino acid) or 7
    characters, in which case the amino acid must be "iMet".  A codon of "Unk"
    becomes "???"; otherwise U is replaced by T.

    Raises Exception for any other length of the last field.
    """
    _, _, label = name.split(":")
    aminoacid_length = {6: 3, 7: 4}.get(len(label))
    if aminoacid_length is None:
        raise Exception("Cannot understand %s" % label)
    aminoacid = label[:aminoacid_length]
    codon = label[aminoacid_length:]
    if aminoacid_length == 4:
        assert aminoacid == 'iMet'
    codon = '???' if codon == 'Unk' else codon.replace('U', 'T')
    return aminoacid + codon


def parse_group(group):
    """Summarise one group of overlapping annotation rows.

    All rows must share chromosome and strand (checked with assert).

    Returns (chromosome, strand, starts, ends, names, exons):
      starts, ends -- per-row 5' and 3' coordinates in transcript orientation,
                      i.e. (row.start, row.end) on '+' and (row.end, row.start)
                      on '-'.
      names        -- set of (source, name) pairs.  BED rows are source 'NCBI'
                      with name row.name; other rows use fields[1] as source
                      and fields[2] as name.
      exons        -- list of [start, end] genomic intervals built from the
                      BED12 block columns of the NCBI row, or None if the
                      group has no BED row.  At most one BED row is allowed.

    Raises ValueError if a row's strand is neither '+' nor '-'.
    """
    starts = []
    ends = []
    names = set()
    strand = None
    chromosome = None
    exons = None
    for row in group:
        if chromosome:
            assert chromosome==row.chrom
        else:
            chromosome = row.chrom
        if strand:
            assert strand==row.strand
        else:
            strand = row.strand
        if strand=='+':
            start = row.start
            end = row.end
        elif strand=='-':
            start = row.end
            end = row.start
        else:
            # Without a strand there is no 5'/3' orientation; fail clearly
            # instead of with an UnboundLocalError on start/end.
            raise ValueError("Unexpected strand %r for %s:%d-%d"
                             % (strand, row.chrom, row.start, row.end))
        starts.append(start)
        ends.append(end)
        if row.file_type == 'bed':
            source = 'NCBI'
            name = row.name
            assert exons is None
            # BED12 columns 10-12: blockCount, blockSizes, blockStarts
            # (comma-separated, relative to row.start).
            blockCount = int(row.fields[9])
            blockSizes = row.fields[10].rstrip(",").split(",")
            blockStarts = row.fields[11].rstrip(",").split(",")
            assert len(blockSizes) == blockCount
            assert len(blockStarts) == blockCount
            assert blockStarts[0] == '0'
            exons = []
            blockEnd = 0
            for i in range(blockCount):
                blockSize = int(blockSizes[i])
                blockStart = int(blockStarts[i])
                # Blocks must be sorted and non-overlapping.
                assert blockStart >= blockEnd
                blockEnd = blockStart + blockSize
                exon = [row.start + blockStart, row.start + blockEnd]
                exons.append(exon)
        else:
            source = row.fields[1]
            name = row.fields[2]
        names.add((source, name))
    return chromosome, strand, starts, ends, names, exons

def make_name(source_names, skip=0):
    """Build a consensus tRNA name (amino acid + codon) for one group.

    Parameters
    ----------
    source_names : iterable of (source, name) pairs, where source is
        'RepeatMasker', 'UCSC' or 'NCBI'.
    skip : 0 uses every source, 1 ignores RepeatMasker, and 2 ignores both
        RepeatMasker and UCSC.  NCBI names are always used.

    Returns
    -------
    "??????" if no name is usable, the single parsed name if there is only
    one, otherwise a consensus.  "???" entries are ignored when the other
    names agree.  A degenerate codon (third base R, Y or N) is reconciled with
    a compatible non-degenerate codon.  On disagreement the function retries
    with the next higher `skip`.

    Raises
    ------
    Exception for an unknown source, an unparseable name, or an unexpected
    degenerate base.
    """
    names = []
    for source, name in source_names:
        if source=='RepeatMasker':
            if skip > 0:
                continue
            name = parse_repeatmasker_name(name)
        elif source=='UCSC':
            if skip > 1:
                continue
            name = parse_ucsc_name(name)
        elif source=='NCBI':
            name = parse_ncbi_name(name)
        else:
            raise Exception("Unknown source")
        names.append(name)
    if len(names)==0:
        return "??????"
    if len(names)==1:
        return names[0]
    aminoacids = set()
    codons = set()
    for name in names:
        if len(name) == 6:
            aminoacid, codon = name[:3], name[3:]
        elif len(name) == 7:
            # 4-letter amino acid "iMet": the codon starts at index 4, matching
            # parse_ncbi_name (the old name[3:] produced e.g. "tCAT").
            aminoacid, codon = name[:4], name[4:]
            assert aminoacid == 'iMet'
        else:
            raise Exception("Cannot understand %s" % name)
        aminoacids.add(aminoacid)
        codons.add(codon)
    if len(aminoacids)!=1:
        aminoacids.discard("???")
        if len(aminoacids)!=1:
            assert skip < 2
            return make_name(source_names, skip+1)
    aminoacid = aminoacids.pop()
    if len(codons)==1:
        codon = codons.pop()
    else:
        codons.discard("???")
        if len(codons)==1:
            codon = codons.pop()
        else:
            degenerate_codons = set()
            nondegenerate_codons = set()
            for candidate in codons:
                nucleotide = candidate[2]
                if nucleotide not in "ACGT":
                    degenerate_codons.add(candidate)
                else:
                    nondegenerate_codons.add(candidate)
                for nucleotide in candidate[:2]:
                    assert nucleotide in "ACGT"
            if len(nondegenerate_codons)!=1:
                assert skip < 2
                return make_name(source_names, skip+1)
            assert len(degenerate_codons)==1
            codon = nondegenerate_codons.pop()
            degenerate_codon = degenerate_codons.pop()
            assert codon[:2]==degenerate_codon[:2]
            degenerate_nucleotide = degenerate_codon[2]
            nucleotide = codon[2]
            if degenerate_nucleotide=='R':  # R must be A or G
                if nucleotide not in "AG":
                    assert skip < 2
                    return make_name(source_names, skip+1)
            elif degenerate_nucleotide=='Y':  # Y must be C or T
                if nucleotide not in "CT":
                    assert skip < 2
                    return make_name(source_names, skip+1)
            elif degenerate_nucleotide=='N':  # N matches anything
                pass
            else:
                raise Exception
    return aminoacid + codon


# Merge each group into one BED12 record and write it to mergedtrnas.bed.
# The 5' end is the most upstream annotated start.  The 3' end is chosen by
# read support (see below).  Names are the consensus name plus a running
# per-name number.
numbers = {}
current = None
filename = "mergedtrnas.bed"
print("Writing", filename)
with open(filename, 'w') as output:
    for group in groups:
        chromosome, strand, starts, ends, names, exons = parse_group(group)
        # starts/ends are in transcript orientation, so the most upstream start
        # is the minimum on '+' and the maximum on '-'.
        if strand=='+':
            start = min(starts)
        elif strand=='-':
            start = max(starts)
        # Groups follow the sorted chromosome order, so each chromosome is read once.
        if chromosome!=current:
            genome = read_genome(chromosome)
            current = chromosome
        # Score each candidate 3' end within 5 nt of the annotated ends.  The
        # score is the number of reads whose 13-nt tag equals the 10 genomic
        # nucleotides just upstream of the candidate (transcript orientation)
        # followed by 'CCA'.
        positions = range(min(ends)-5, max(ends)+5)
        tagcounts = zeros(len(positions), int)
        for i, position in enumerate(positions):
            if strand=='+':
                sequence = genome[position-10:position]
            if strand=='-':
                sequence = genome[position:position+10]
                sequence = reverse_complement(sequence)
            sequence += 'CCA'
            tagcounts[i] = counts.get(sequence, 0)
        i = argmax(tagcounts)
        end = positions[i]
        if not end in ends:
            # Keep an unannotated end only if it has >= 90% of all tags and at
            # least 1000 reads.  Otherwise use the best-supported annotated end.
            if tagcounts[i] < max(0.9 * sum(tagcounts), 1000):
                maximum = 0
                for position in ends:
                    i = positions.index(position)
                    count = tagcounts[i]
                    if count > maximum:
                        maximum = count
                        end = position
                if maximum == 0:
                    # No annotated end has any reads.  Use the NCBI exon boundary
                    # if available, otherwise the most downstream annotated end.
                    if exons is not None:  # NCBI result
                        if strand == '+':
                            end = exons[-1][1]
                        elif strand == '-':
                            end = exons[0][0]
                    else:
                        assert len(names) == len(ends)
                        if strand == '+':
                            end = max(ends)
                        elif strand == '-':
                            end = min(ends)
        # Convert back to genomic orientation (start < end) for BED output.
        if strand == '-':
            start, end = end, start
        assert end > start
        name = make_name(names)
        number = numbers.get(name, 1)
        numbers[name] = number+1
        name += str(number)
        name = "tRNA|%s" % name
        # Zero-length thick region (thickStart == thickEnd).
        thickStart = start
        thickEnd = start
        if exons is None:
            blockSizes, blockStarts = [end-start], [0]
        else:
            # Stretch the outer NCBI exons to the chosen boundaries.
            exons[0][0] = start
            exons[-1][1] = end
            blockSizes = []
            blockStarts = []
            for exon in exons:
                exon_start, exon_end = exon
                blockStart = exon_start - start
                blockSize = exon_end - exon_start
                blockStarts.append(blockStart)
                blockSizes.append(blockSize)
        blockCount = len(blockSizes)
        assert blockStarts[0] == 0
        blockStarts = ",".join([str(blockStart) for blockStart in blockStarts]) + ","
        blockSizes = ",".join([str(blockSize) for blockSize in blockSizes]) + ","
        if strand == '+':
            if end not in ends:
                print("Changed %s to %s:%d" % (name, chromosome, end))
        elif strand == '-':
            if start not in ends:
                print("Changed %s to %s:%d" % (name, chromosome, start))
        row = (chromosome, start, end, name, strand, thickStart, thickEnd, blockCount, blockSizes, blockStarts)
        line = "%s\t%d\t%d\t%s\t.\t%s\t%d\t%d\t.\t%d\t%s\t%s\n" % row
        output.write(line)
